add differentiable mxfp8 grouped gemm with dynamic quant (forward pass)#2627
danielvegamyhre merged 1 commit into main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2627
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures
As of commit 5376f65 with merge base 9834869.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
stack-info: PR: #2627, branch: danielvegamyhre/stack/23
    offs: Optional[torch.Tensor] = None,
    block_size: int = 32,
    out_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.Tensor:
Could we add an `emulated` flag and assert that it's True until we have a real kernel, to make the intent crystal clear?
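A minimal sketch of what this suggestion could look like, not the PR's actual code. The function name `grouped_mm_sketch`, the list-of-lists matrix type, and the plain matmul body are all assumptions made for illustration; only `block_size` and the idea of an `emulated` flag come from this thread.

```python
from typing import List

Matrix = List[List[float]]


def grouped_mm_sketch(
    A: Matrix,
    B: Matrix,
    block_size: int = 32,   # mirrors the signature above; unused in this stand-in
    emulated: bool = True,  # hypothetical flag, per the review suggestion
) -> Matrix:
    # Fail loudly if a caller expects a real fused kernel.
    assert emulated, "only the emulated path exists; no real mxfp8 grouped GEMM kernel yet"
    # Stand-in for the emulated path: a plain high-precision matmul.
    # Quantization and grouping are deliberately omitted here.
    return [
        [sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
        for row in A
    ]
```

With a guard like this, passing `emulated=False` fails immediately instead of silently running the slow reference path.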
    assert sqnr >= min_sqnr, f"sqnr {sqnr} is too low, must be >= {min_sqnr}"


@pytest.mark.parametrize("M", (1024, 4096))
nit: it might be good to test one case where M, K, and N are all the same, and one where they are all different. To do that while keeping the number of tests manageable, you could parametrize over (M, K, N) together in one go instead of over each dimension individually.
Updated to parameterize M,N,K together and test "all same" and "all different" cases
Ergh my stack-pr got in a weird state, changes didn't go through somehow... let me try again
Ok it's updated now.
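For reference, parametrizing M, K, and N together as discussed above could look roughly like the sketch below. The shape values, the `MKN_CASES` name, the test name, and the placeholder body are illustrative assumptions; the values actually used in the PR are not shown in this thread.

```python
import pytest

# Illustrative shapes only (assumed, not taken from the PR).
MKN_CASES = [
    (1024, 1024, 1024),  # all dimensions the same
    (1024, 2048, 4096),  # all dimensions different
]


@pytest.mark.parametrize("M,K,N", MKN_CASES)
def test_grouped_mm_shapes_sketch(M, K, N):
    # Placeholder body: a real test would run the grouped GEMM and check SQNR.
    # Here we only check that each dimension divides evenly into blocks of 32.
    block_size = 32
    assert M % block_size == 0
    assert K % block_size == 0
    assert N % block_size == 0
```

This yields two test cases rather than the eight that three independent two-value parametrizations would produce.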
Stacked PRs:
add differentiable mxfp8 grouped gemm with dynamic quant (forward pass)